{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "05_pytorch_going_modular_exercise_solutions.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyMV7Vpu90H5EsauLjBra4IS",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mrdbourke/pytorch-deep-learning/blob/main/extras/solutions/05_pytorch_going_modular_exercise_solutions.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 05. PyTorch Going Modular Exercise Solutions\n",
        "\n",
        "Welcome to the 05. PyTorch Going Modular exercise solutions notebook.\n",
        "\n",
        "> **Note:** There may be more than one solution to each of the exercises. This notebook only shows one possible example.\n",
        "\n",
        "## Resources\n",
        "\n",
        "1. These exercises/solutions are based on [section 05. PyTorch Going Modular](https://www.learnpytorch.io/05_pytorch_going_modular/) of the Learn PyTorch for Deep Learning course by Zero to Mastery.\n",
        "2. See a live [walkthrough of the solutions (errors and all) on YouTube](https://youtu.be/ijgFhMK3pp4).\n",
        "3. See [other solutions on the course GitHub](https://github.com/mrdbourke/pytorch-deep-learning/tree/main/extras/solutions)."
      ],
      "metadata": {
        "id": "zNqPNlYylluR"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1. Turn the code to get the data (from section 1. Get Data) into a Python script, such as `get_data.py`.\n",
        "\n",
        "* When you run the script using `python get_data.py` it should check if the data already exists and skip downloading if it does.\n",
        "* If the data download is successful, you should be able to access the `pizza_steak_sushi` images from the `data` directory."
      ],
      "metadata": {
        "id": "bicbWSrPmfTU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile get_data.py\n",
        "import os\n",
        "import zipfile\n",
        "\n",
        "from pathlib import Path\n",
        "\n",
        "import requests\n",
        "\n",
        "# Setup path to data folder\n",
        "data_path = Path(\"data/\")\n",
        "image_path = data_path / \"pizza_steak_sushi\"\n",
        "\n",
        "# If the image folder doesn't exist, download it and prepare it... \n",
        "if image_path.is_dir():\n",
        "    print(f\"{image_path} directory exists.\")\n",
        "else:\n",
        "    print(f\"Did not find {image_path} directory, creating one...\")\n",
        "    image_path.mkdir(parents=True, exist_ok=True)\n",
        "    \n",
        "# Download pizza, steak, sushi data\n",
        "with open(data_path / \"pizza_steak_sushi.zip\", \"wb\") as f:\n",
        "    request = requests.get(\"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip\")\n",
        "    print(\"Downloading pizza, steak, sushi data...\")\n",
        "    f.write(request.content)\n",
        "\n",
        "# Unzip pizza, steak, sushi data\n",
        "with zipfile.ZipFile(data_path / \"pizza_steak_sushi.zip\", \"r\") as zip_ref:\n",
        "    print(\"Unzipping pizza, steak, sushi data...\") \n",
        "    zip_ref.extractall(image_path)\n",
        "\n",
        "# Remove zip file\n",
        "os.remove(data_path / \"pizza_steak_sushi.zip\")"
      ],
      "metadata": {
        "id": "D-mTTwPjmt4f",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b482f845-1f78-4b4f-826a-c0c050c1b6ef"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing get_data.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!python get_data.py"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WAtJ-n10BWHZ",
        "outputId": "9c9b896c-a339-4d9a-f76d-69bce5261961"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Did not find data/pizza_steak_sushi directory, creating one...\n",
            "Downloading pizza, steak, sushi data...\n",
            "Unzipping pizza, steak, sushi data...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2. Use [Python's `argparse` module](https://docs.python.org/3/library/argparse.html) to be able to send the `train.py` custom hyperparameter values for training procedures.\n",
        "* Add an argument flag for using a different:\n",
        "  * Training/testing directory\n",
        "  * Learning rate\n",
        "  * Batch size\n",
        "  * Number of epochs to train for\n",
        "  * Number of hidden units in the TinyVGG model\n",
        "    * Keep the default values for each of the above arguments as what they already are (as in notebook 05).\n",
        "* For example, you should be able to run something similar to the following line to train a TinyVGG model with a learning rate of 0.003 and a batch size of 64 for 20 epochs: `python train.py --learning_rate 0.003 batch_size 64 num_epochs 20`.\n",
        "* **Note:** Since `train.py` leverages the other scripts we created in section 05, such as, `model_builder.py`, `utils.py` and `engine.py`, you'll have to make sure they're available to use too. You can find these in the [`going_modular` folder on the course GitHub](https://github.com/mrdbourke/pytorch-deep-learning/tree/main/going_modular/going_modular). "
      ],
      "metadata": {
        "id": "zjyn7LU3mvkR"
      }
    },
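    {
      "cell_type": "markdown",
      "source": [
        "As a rough illustration of the flag-based interface this exercise asks for, the sketch below defines two flags and parses them. It is a minimal example rather than the full solution: the parser description is made up, the defaults are assumed to match `train.py` (learning rate 0.001, batch size 32), and a list is passed to `parse_args()` to stand in for the command line so the code can run inside a notebook.\n",
        "\n",
        "```python\n",
        "import argparse\n",
        "\n",
        "# Hypothetical minimal parser with two of the flags from the exercise\n",
        "parser = argparse.ArgumentParser(description=\"Sketch of hyperparameter flags.\")\n",
        "parser.add_argument(\"--learning_rate\", type=float, default=0.001)\n",
        "parser.add_argument(\"--batch_size\", type=int, default=32)\n",
        "\n",
        "# Passing a list mimics: python train.py --learning_rate 0.003\n",
        "args = parser.parse_args([\"--learning_rate\", \"0.003\"])\n",
        "print(args.learning_rate, args.batch_size)\n",
        "```\n",
        "\n",
        "Any flag left out, such as `--batch_size` here, falls back to its default value."
      ],
      "metadata": {}
    },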
    {
      "cell_type": "code",
      "source": [
        "%%writefile data_setup.py\n",
        "\"\"\"\n",
        "Contains functionality for creating PyTorch DataLoaders for \n",
        "image classification data.\n",
        "\"\"\"\n",
        "import os\n",
        "\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "NUM_WORKERS = os.cpu_count()\n",
        "\n",
        "def create_dataloaders(\n",
        "    train_dir: str, \n",
        "    test_dir: str, \n",
        "    transform: transforms.Compose, \n",
        "    batch_size: int, \n",
        "    num_workers: int=NUM_WORKERS\n",
        "):\n",
        "  \"\"\"Creates training and testing DataLoaders.\n",
        "  Takes in a training directory and testing directory path and turns\n",
        "  them into PyTorch Datasets and then into PyTorch DataLoaders.\n",
        "  Args:\n",
        "    train_dir: Path to training directory.\n",
        "    test_dir: Path to testing directory.\n",
        "    transform: torchvision transforms to perform on training and testing data.\n",
        "    batch_size: Number of samples per batch in each of the DataLoaders.\n",
        "    num_workers: An integer for number of workers per DataLoader.\n",
        "  Returns:\n",
        "    A tuple of (train_dataloader, test_dataloader, class_names).\n",
        "    Where class_names is a list of the target classes.\n",
        "    Example usage:\n",
        "      train_dataloader, test_dataloader, class_names = \\\n",
        "        = create_dataloaders(train_dir=path/to/train_dir,\n",
        "                             test_dir=path/to/test_dir,\n",
        "                             transform=some_transform,\n",
        "                             batch_size=32,\n",
        "                             num_workers=4)\n",
        "  \"\"\"\n",
        "  # Use ImageFolder to create dataset(s)\n",
        "  train_data = datasets.ImageFolder(train_dir, transform=transform)\n",
        "  test_data = datasets.ImageFolder(test_dir, transform=transform)\n",
        "\n",
        "  # Get class names\n",
        "  class_names = train_data.classes\n",
        "\n",
        "  # Turn images into data loaders\n",
        "  train_dataloader = DataLoader(\n",
        "      train_data,\n",
        "      batch_size=batch_size,\n",
        "      shuffle=True,\n",
        "      num_workers=num_workers,\n",
        "      pin_memory=True,\n",
        "  )\n",
        "  test_dataloader = DataLoader(\n",
        "      test_data,\n",
        "      batch_size=batch_size,\n",
        "      shuffle=False,\n",
        "      num_workers=num_workers,\n",
        "      pin_memory=True,\n",
        "  )\n",
        "\n",
        "  return train_dataloader, test_dataloader, class_names"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QQ4GoYVpDg3g",
        "outputId": "11511586-1a10-43fd-b60b-81a9b32f5a8e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing data_setup.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile engine.py\n",
        "\"\"\"\n",
        "Contains functions for training and testing a PyTorch model.\n",
        "\"\"\"\n",
        "import torch\n",
        "\n",
        "from tqdm.auto import tqdm\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "def train_step(model: torch.nn.Module, \n",
        "               dataloader: torch.utils.data.DataLoader, \n",
        "               loss_fn: torch.nn.Module, \n",
        "               optimizer: torch.optim.Optimizer,\n",
        "               device: torch.device) -> Tuple[float, float]:\n",
        "    \"\"\"Trains a PyTorch model for a single epoch.\n",
        "    Turns a target PyTorch model to training mode and then\n",
        "    runs through all of the required training steps (forward\n",
        "    pass, loss calculation, optimizer step).\n",
        "    Args:\n",
        "    model: A PyTorch model to be trained.\n",
        "    dataloader: A DataLoader instance for the model to be trained on.\n",
        "    loss_fn: A PyTorch loss function to minimize.\n",
        "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "    Returns:\n",
        "    A tuple of training loss and training accuracy metrics.\n",
        "    In the form (train_loss, train_accuracy). For example:\n",
        "    (0.1112, 0.8743)\n",
        "    \"\"\"\n",
        "    # Put model in train mode\n",
        "    model.train()\n",
        "\n",
        "    # Setup train loss and train accuracy values\n",
        "    train_loss, train_acc = 0, 0\n",
        "\n",
        "    # Loop through data loader data batches\n",
        "    for batch, (X, y) in enumerate(dataloader):\n",
        "        # Send data to target device\n",
        "        X, y = X.to(device), y.to(device)\n",
        "\n",
        "        # 1. Forward pass\n",
        "        y_pred = model(X)\n",
        "\n",
        "        # 2. Calculate  and accumulate loss\n",
        "        loss = loss_fn(y_pred, y)\n",
        "        train_loss += loss.item() \n",
        "\n",
        "        # 3. Optimizer zero grad\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        # 4. Loss backward\n",
        "        loss.backward()\n",
        "\n",
        "        # 5. Optimizer step\n",
        "        optimizer.step()\n",
        "\n",
        "        # Calculate and accumulate accuracy metric across all batches\n",
        "        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)\n",
        "        train_acc += (y_pred_class == y).sum().item()/len(y_pred)\n",
        "\n",
        "    # Adjust metrics to get average loss and accuracy per batch \n",
        "    train_loss = train_loss / len(dataloader)\n",
        "    train_acc = train_acc / len(dataloader)\n",
        "    return train_loss, train_acc\n",
        "\n",
        "def test_step(model: torch.nn.Module, \n",
        "              dataloader: torch.utils.data.DataLoader, \n",
        "              loss_fn: torch.nn.Module,\n",
        "              device: torch.device) -> Tuple[float, float]:\n",
        "    \"\"\"Tests a PyTorch model for a single epoch.\n",
        "    Turns a target PyTorch model to \"eval\" mode and then performs\n",
        "    a forward pass on a testing dataset.\n",
        "    Args:\n",
        "    model: A PyTorch model to be tested.\n",
        "    dataloader: A DataLoader instance for the model to be tested on.\n",
        "    loss_fn: A PyTorch loss function to calculate loss on the test data.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "    Returns:\n",
        "    A tuple of testing loss and testing accuracy metrics.\n",
        "    In the form (test_loss, test_accuracy). For example:\n",
        "    (0.0223, 0.8985)\n",
        "    \"\"\"\n",
        "    # Put model in eval mode\n",
        "    model.eval() \n",
        "\n",
        "    # Setup test loss and test accuracy values\n",
        "    test_loss, test_acc = 0, 0\n",
        "\n",
        "    # Turn on inference context manager\n",
        "    with torch.inference_mode():\n",
        "        # Loop through DataLoader batches\n",
        "        for batch, (X, y) in enumerate(dataloader):\n",
        "            # Send data to target device\n",
        "            X, y = X.to(device), y.to(device)\n",
        "\n",
        "            # 1. Forward pass\n",
        "            test_pred_logits = model(X)\n",
        "\n",
        "            # 2. Calculate and accumulate loss\n",
        "            loss = loss_fn(test_pred_logits, y)\n",
        "            test_loss += loss.item()\n",
        "\n",
        "            # Calculate and accumulate accuracy\n",
        "            test_pred_labels = test_pred_logits.argmax(dim=1)\n",
        "            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))\n",
        "\n",
        "    # Adjust metrics to get average loss and accuracy per batch \n",
        "    test_loss = test_loss / len(dataloader)\n",
        "    test_acc = test_acc / len(dataloader)\n",
        "    return test_loss, test_acc\n",
        "\n",
        "def train(model: torch.nn.Module, \n",
        "          train_dataloader: torch.utils.data.DataLoader, \n",
        "          test_dataloader: torch.utils.data.DataLoader, \n",
        "          optimizer: torch.optim.Optimizer,\n",
        "          loss_fn: torch.nn.Module,\n",
        "          epochs: int,\n",
        "          device: torch.device) -> Dict[str, List]:\n",
        "    \"\"\"Trains and tests a PyTorch model.\n",
        "    Passes a target PyTorch models through train_step() and test_step()\n",
        "    functions for a number of epochs, training and testing the model\n",
        "    in the same epoch loop.\n",
        "    Calculates, prints and stores evaluation metrics throughout.\n",
        "    Args:\n",
        "    model: A PyTorch model to be trained and tested.\n",
        "    train_dataloader: A DataLoader instance for the model to be trained on.\n",
        "    test_dataloader: A DataLoader instance for the model to be tested on.\n",
        "    optimizer: A PyTorch optimizer to help minimize the loss function.\n",
        "    loss_fn: A PyTorch loss function to calculate loss on both datasets.\n",
        "    epochs: An integer indicating how many epochs to train for.\n",
        "    device: A target device to compute on (e.g. \"cuda\" or \"cpu\").\n",
        "    Returns:\n",
        "    A dictionary of training and testing loss as well as training and\n",
        "    testing accuracy metrics. Each metric has a value in a list for \n",
        "    each epoch.\n",
        "    In the form: {train_loss: [...],\n",
        "              train_acc: [...],\n",
        "              test_loss: [...],\n",
        "              test_acc: [...]} \n",
        "    For example if training for epochs=2: \n",
        "             {train_loss: [2.0616, 1.0537],\n",
        "              train_acc: [0.3945, 0.3945],\n",
        "              test_loss: [1.2641, 1.5706],\n",
        "              test_acc: [0.3400, 0.2973]} \n",
        "    \"\"\"\n",
        "    # Create empty results dictionary\n",
        "    results = {\"train_loss\": [],\n",
        "               \"train_acc\": [],\n",
        "               \"test_loss\": [],\n",
        "               \"test_acc\": []\n",
        "    }\n",
        "\n",
        "    # Loop through training and testing steps for a number of epochs\n",
        "    for epoch in tqdm(range(epochs)):\n",
        "        train_loss, train_acc = train_step(model=model,\n",
        "                                          dataloader=train_dataloader,\n",
        "                                          loss_fn=loss_fn,\n",
        "                                          optimizer=optimizer,\n",
        "                                          device=device)\n",
        "        test_loss, test_acc = test_step(model=model,\n",
        "          dataloader=test_dataloader,\n",
        "          loss_fn=loss_fn,\n",
        "          device=device)\n",
        "\n",
        "        # Print out what's happening\n",
        "        print(\n",
        "          f\"Epoch: {epoch+1} | \"\n",
        "          f\"train_loss: {train_loss:.4f} | \"\n",
        "          f\"train_acc: {train_acc:.4f} | \"\n",
        "          f\"test_loss: {test_loss:.4f} | \"\n",
        "          f\"test_acc: {test_acc:.4f}\"\n",
        "        )\n",
        "\n",
        "        # Update results dictionary\n",
        "        results[\"train_loss\"].append(train_loss)\n",
        "        results[\"train_acc\"].append(train_acc)\n",
        "        results[\"test_loss\"].append(test_loss)\n",
        "        results[\"test_acc\"].append(test_acc)\n",
        "\n",
        "    # Return the filled results at the end of the epochs\n",
        "    return results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EzD2EHmsDqMg",
        "outputId": "8c15e6f3-14cf-444e-c246-defed3e81a29"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing engine.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile model_builder.py\n",
        "\"\"\"\n",
        "Contains PyTorch model code to instantiate a TinyVGG model.\n",
        "\"\"\"\n",
        "import torch\n",
        "from torch import nn \n",
        "\n",
        "class TinyVGG(nn.Module):\n",
        "    \"\"\"Creates the TinyVGG architecture.\n",
        "    Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.\n",
        "    See the original architecture here: https://poloclub.github.io/cnn-explainer/\n",
        "    Args:\n",
        "    input_shape: An integer indicating number of input channels.\n",
        "    hidden_units: An integer indicating number of hidden units between layers.\n",
        "    output_shape: An integer indicating number of output units.\n",
        "    \"\"\"\n",
        "    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:\n",
        "        super().__init__()\n",
        "        self.conv_block_1 = nn.Sequential(\n",
        "          nn.Conv2d(in_channels=input_shape, \n",
        "                    out_channels=hidden_units, \n",
        "                    kernel_size=3, \n",
        "                    stride=1, \n",
        "                    padding=0),  \n",
        "          nn.ReLU(),\n",
        "          nn.Conv2d(in_channels=hidden_units, \n",
        "                    out_channels=hidden_units,\n",
        "                    kernel_size=3,\n",
        "                    stride=1,\n",
        "                    padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.MaxPool2d(kernel_size=2,\n",
        "                        stride=2)\n",
        "        )\n",
        "        self.conv_block_2 = nn.Sequential(\n",
        "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),\n",
        "          nn.ReLU(),\n",
        "          nn.MaxPool2d(2)\n",
        "        )\n",
        "        self.classifier = nn.Sequential(\n",
        "          nn.Flatten(),\n",
        "          # Where did this in_features shape come from? \n",
        "          # It's because each layer of our network compresses and changes the shape of our inputs data.\n",
        "          nn.Linear(in_features=hidden_units*13*13,\n",
        "                    out_features=output_shape)\n",
        "        )\n",
        "    \n",
        "    def forward(self, x: torch.Tensor):\n",
        "        x = self.conv_block_1(x)\n",
        "        x = self.conv_block_2(x)\n",
        "        x = self.classifier(x)\n",
        "        return x\n",
        "        # return self.classifier(self.block_2(self.block_1(x))) # <- leverage the benefits of operator fusion"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "V0jywo2jDx-d",
        "outputId": "ea63d29d-e4ce-4057-da82-5ea92c1a1e95"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing model_builder.py\n"
          ]
        }
      ]
    },
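    {
      "cell_type": "markdown",
      "source": [
        "The `13*13` in `in_features=hidden_units*13*13` can be checked by hand. The sketch below uses a hypothetical helper (`conv_out` is not part of `model_builder.py` or PyTorch) that applies the standard output-size formula for a convolution or pooling layer. It assumes 64x64 input images, as produced by the `Resize((64, 64))` transform in `train.py`.\n",
        "\n",
        "```python\n",
        "def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 0) -> int:\n",
        "    \"\"\"Spatial size after a conv or pooling layer (dilation of 1 assumed).\"\"\"\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "size = 64  # assumed input height/width\n",
        "for _ in range(2):  # TinyVGG has two conv blocks\n",
        "    size = conv_out(size)  # first 3x3 conv, no padding\n",
        "    size = conv_out(size)  # second 3x3 conv, no padding\n",
        "    size = conv_out(size, kernel=2, stride=2)  # 2x2 max pool\n",
        "print(size)  # 13\n",
        "```\n",
        "\n",
        "If the input image size or the layer settings change, `in_features` has to change with them."
      ],
      "metadata": {}
    },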
    {
      "cell_type": "code",
      "source": [
        "%%writefile utils.py\n",
        "\"\"\"\n",
        "Contains various utility functions for PyTorch model training and saving.\n",
        "\"\"\"\n",
        "import torch\n",
        "from pathlib import Path\n",
        "\n",
        "def save_model(model: torch.nn.Module,\n",
        "               target_dir: str,\n",
        "               model_name: str):\n",
        "    \"\"\"Saves a PyTorch model to a target directory.\n",
        "    Args:\n",
        "    model: A target PyTorch model to save.\n",
        "    target_dir: A directory for saving the model to.\n",
        "    model_name: A filename for the saved model. Should include\n",
        "      either \".pth\" or \".pt\" as the file extension.\n",
        "    Example usage:\n",
        "    save_model(model=model_0,\n",
        "               target_dir=\"models\",\n",
        "               model_name=\"05_going_modular_tingvgg_model.pth\")\n",
        "    \"\"\"\n",
        "    # Create target directory\n",
        "    target_dir_path = Path(target_dir)\n",
        "    target_dir_path.mkdir(parents=True,\n",
        "                        exist_ok=True)\n",
        "\n",
        "    # Create model save path\n",
        "    assert model_name.endswith(\".pth\") or model_name.endswith(\".pt\"), \"model_name should end with '.pt' or '.pth'\"\n",
        "    model_save_path = target_dir_path / model_name\n",
        "\n",
        "    # Save the model state_dict()\n",
        "    print(f\"[INFO] Saving model to: {model_save_path}\")\n",
        "    torch.save(obj=model.state_dict(),\n",
        "             f=model_save_path)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wk7FsALmD1pI",
        "outputId": "344661fd-f1b7-4cde-8f65-5c3b707023e2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing utils.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile train.py\n",
        "\"\"\"\n",
        "Trains a PyTorch image classification model using device-agnostic code.\n",
        "\"\"\"\n",
        "\n",
        "import os\n",
        "import argparse\n",
        "\n",
        "import torch\n",
        "\n",
        "from torchvision import transforms\n",
        "\n",
        "import data_setup, engine, model_builder, utils\n",
        "\n",
        "# Create a parser\n",
        "parser = argparse.ArgumentParser(description=\"Get some hyperparameters.\")\n",
        "\n",
        "# Get an arg for num_epochs\n",
        "parser.add_argument(\"--num_epochs\", \n",
        "                     default=10, \n",
        "                     type=int, \n",
        "                     help=\"the number of epochs to train for\")\n",
        "\n",
        "# Get an arg for batch_size\n",
        "parser.add_argument(\"--batch_size\",\n",
        "                    default=32,\n",
        "                    type=int,\n",
        "                    help=\"number of samples per batch\")\n",
        "\n",
        "# Get an arg for hidden_units\n",
        "parser.add_argument(\"--hidden_units\",\n",
        "                    default=10,\n",
        "                    type=int,\n",
        "                    help=\"number of hidden units in hidden layers\")\n",
        "\n",
        "# Get an arg for learning_rate\n",
        "parser.add_argument(\"--learning_rate\",\n",
        "                    default=0.001,\n",
        "                    type=float,\n",
        "                    help=\"learning rate to use for model\")\n",
        "\n",
        "# Create an arg for training directory \n",
        "parser.add_argument(\"--train_dir\",\n",
        "                    default=\"data/pizza_steak_sushi/train\",\n",
        "                    type=str,\n",
        "                    help=\"directory file path to training data in standard image classification format\")\n",
        "\n",
        "# Create an arg for test directory \n",
        "parser.add_argument(\"--test_dir\",\n",
        "                    default=\"data/pizza_steak_sushi/test\",\n",
        "                    type=str,\n",
        "                    help=\"directory file path to testing data in standard image classification format\")\n",
        "\n",
        "# Get our arguments from the parser\n",
        "args = parser.parse_args()\n",
        "\n",
        "# Setup hyperparameters\n",
        "NUM_EPOCHS = args.num_epochs\n",
        "BATCH_SIZE = args.batch_size\n",
        "HIDDEN_UNITS = args.hidden_units\n",
        "LEARNING_RATE = args.learning_rate\n",
        "print(f\"[INFO] Training a model for {NUM_EPOCHS} epochs with batch size {BATCH_SIZE} using {HIDDEN_UNITS} hidden units and a learning rate of {LEARNING_RATE}\")\n",
        "\n",
        "# Setup directories\n",
        "train_dir = args.train_dir\n",
        "test_dir = args.test_dir\n",
        "print(f\"[INFO] Training data file: {train_dir}\")\n",
        "print(f\"[INFO] Testing data file: {test_dir}\")\n",
        "\n",
        "# Setup target device\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Create transforms\n",
        "data_transform = transforms.Compose([\n",
        "  transforms.Resize((64, 64)),\n",
        "  transforms.ToTensor()\n",
        "])\n",
        "\n",
        "# Create DataLoaders with help from data_setup.py\n",
        "train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(\n",
        "    train_dir=train_dir,\n",
        "    test_dir=test_dir,\n",
        "    transform=data_transform,\n",
        "    batch_size=BATCH_SIZE\n",
        ")\n",
        "\n",
        "# Create model with help from model_builder.py\n",
        "model = model_builder.TinyVGG(\n",
        "    input_shape=3,\n",
        "    hidden_units=HIDDEN_UNITS,\n",
        "    output_shape=len(class_names)\n",
        ").to(device)\n",
        "\n",
        "# Set loss and optimizer\n",
        "loss_fn = torch.nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.Adam(model.parameters(),\n",
        "                             lr=LEARNING_RATE)\n",
        "\n",
        "# Start training with help from engine.py\n",
        "engine.train(model=model,\n",
        "             train_dataloader=train_dataloader,\n",
        "             test_dataloader=test_dataloader,\n",
        "             loss_fn=loss_fn,\n",
        "             optimizer=optimizer,\n",
        "             epochs=NUM_EPOCHS,\n",
        "             device=device)\n",
        "\n",
        "# Save the model with help from utils.py\n",
        "utils.save_model(model=model,\n",
        "                 target_dir=\"models\",\n",
        "                 model_name=\"05_going_modular_script_mode_tinyvgg_model.pth\")"
      ],
      "metadata": {
        "id": "wYVTSZ5sm02-",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "aa595a1e-af61-4294-dcbd-6ebee5849fec"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing train.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!python train.py --num_epochs 5 --batch_size 128 --hidden_units 128 --learning_rate 0.0003"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LzaJl39lC40N",
        "outputId": "a5cf39b4-b69c-4b32-9ce7-0f025a18d582"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[INFO] Training a model for 5 epochs with batch size 128 using 128 hidden units and a learning rate of 0.0003\n",
            "[INFO] Training data file: data/pizza_steak_sushi/train\n",
            "[INFO] Testing data file: data/pizza_steak_sushi/test\n",
            "  0% 0/5 [00:00<?, ?it/s]Epoch: 1 | train_loss: 1.0994 | train_acc: 0.3082 | test_loss: 1.0933 | test_acc: 0.3333\n",
            " 20% 1/5 [00:01<00:07,  1.94s/it]Epoch: 2 | train_loss: 1.0897 | train_acc: 0.3859 | test_loss: 1.0803 | test_acc: 0.3733\n",
            " 40% 2/5 [00:03<00:04,  1.65s/it]Epoch: 3 | train_loss: 1.0683 | train_acc: 0.3926 | test_loss: 1.0466 | test_acc: 0.4800\n",
            " 60% 3/5 [00:04<00:03,  1.56s/it]Epoch: 4 | train_loss: 1.0318 | train_acc: 0.4898 | test_loss: 1.0320 | test_acc: 0.4267\n",
            " 80% 4/5 [00:06<00:01,  1.55s/it]Epoch: 5 | train_loss: 0.9779 | train_acc: 0.5560 | test_loss: 1.0151 | test_acc: 0.4000\n",
            "100% 5/5 [00:07<00:00,  1.57s/it]\n",
            "[INFO] Saving model to: models/05_going_modular_script_mode_tinyvgg_model.pth\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3. Create a Python script to predict (such as `predict.py`) on a target image given a file path with a saved model.\n",
        "\n",
        "* For example, you should be able to run the command `python predict.py some_image.jpeg` and have a trained PyTorch model predict on the image and return its prediction.\n",
        "* To see example prediction code, check out the [predicting on a custom image section in notebook 04](https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function). \n",
        "* You may also have to write code to load in a trained model."
      ],
      "metadata": {
        "id": "P2g6EEYvm-46"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile predict.py\n",
        "import torch\n",
        "import torchvision\n",
        "import argparse\n",
        "\n",
        "import model_builder\n",
        "\n",
        "# Creating a parser\n",
        "parser = argparse.ArgumentParser()\n",
        "\n",
        "# Get an image path\n",
        "parser.add_argument(\"--image\",\n",
        "                    type=str,\n",
        "                    required=True,\n",
        "                    help=\"filepath of the target image to predict on\")\n",
        "\n",
        "# Get a model path\n",
        "parser.add_argument(\"--model_path\",\n",
        "                    default=\"models/05_going_modular_script_mode_tinyvgg_model.pth\",\n",
        "                    type=str,\n",
        "                    help=\"filepath of the saved model to use for prediction\")\n",
        "\n",
        "args = parser.parse_args()\n",
        "\n",
        "# Setup class names\n",
        "class_names = [\"pizza\", \"steak\", \"sushi\"]\n",
        "\n",
        "# Setup device\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Get the image path\n",
        "IMG_PATH = args.image\n",
        "print(f\"[INFO] Predicting on {IMG_PATH}\")\n",
        "\n",
        "# Function to load in the model\n",
        "def load_model(filepath=args.model_path):\n",
        "  # Need to use the same hyperparameters as the saved model (trained above with --hidden_units 128)\n",
        "  model = model_builder.TinyVGG(input_shape=3,\n",
        "                                hidden_units=128,\n",
        "                                output_shape=3).to(device)\n",
        "\n",
        "  print(f\"[INFO] Loading in model from: {filepath}\")\n",
        "  # Load in the saved model state dictionary from file (map_location lets a GPU-saved model load on CPU)\n",
        "  model.load_state_dict(torch.load(filepath, map_location=device))\n",
        "\n",
        "  return model\n",
        "\n",
        "# Function to load in model + predict on select image\n",
        "def predict_on_image(image_path=IMG_PATH, filepath=args.model_path):\n",
        "  # Load the model\n",
        "  model = load_model(filepath)\n",
        "\n",
        "  # Load in the image and turn it into torch.float32 (same type as model)\n",
        "  image = torchvision.io.read_image(str(image_path)).type(torch.float32)\n",
        "\n",
        "  # Preprocess the image to get it between 0 and 1\n",
        "  image = image / 255.\n",
        "\n",
        "  # Resize the image to the input size the model expects (64x64)\n",
        "  transform = torchvision.transforms.Resize(size=(64, 64))\n",
        "  image = transform(image) \n",
        "\n",
        "  # Predict on image\n",
        "  model.eval()\n",
        "  with torch.inference_mode():\n",
        "    # Put image to target device\n",
        "    image = image.to(device)\n",
        "\n",
        "    # Get pred logits\n",
        "    pred_logits = model(image.unsqueeze(dim=0)) # make sure image has batch dimension (shape: [batch_size, color_channels, height, width])\n",
        "\n",
        "    # Get pred probs\n",
        "    pred_prob = torch.softmax(pred_logits, dim=1)\n",
        "\n",
        "    # Get pred label (convert the single-element tensor to a Python int to index class_names)\n",
        "    pred_label = torch.argmax(pred_prob, dim=1)\n",
        "    pred_label_class = class_names[pred_label.item()]\n",
        "\n",
        "  print(f\"[INFO] Pred class: {pred_label_class}, Pred prob: {pred_prob.max():.3f}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "  predict_on_image()"
      ],
      "metadata": {
        "id": "0CyTzC3tnCon",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3baa1cd0-ebb3-410b-f6e5-ef50f9ad17a2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing predict.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!python predict.py --image data/pizza_steak_sushi/test/sushi/175783.jpg"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Zcvw9sitIn6r",
        "outputId": "dd388c00-a10f-4c14-bdd3-51f507e20a12"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[INFO] Predicting on data/pizza_steak_sushi/test/sushi/175783.jpg\n",
            "[INFO] Loading in model from: models/05_going_modular_script_mode_tinyvgg_model.pth\n",
            "[INFO] Pred class: steak, Pred prob: 0.376\n"
          ]
        }
      ]
    }
  ]
}